{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "from sklearn.metrics import r2_score\n",
    "\n",
    "class LogisticRegression:\n",
    "    \"\"\"Binary logistic regression classifier trained with batch gradient descent.\"\"\"\n",
    "\n",
    "    def __init__(self):\n",
    "        \"\"\"Create an unfitted model; fit() fills in the parameters.\"\"\"\n",
    "        self.coef_ = None  # feature weights, theta[1:]\n",
    "        self.interception_ = None  # intercept term, theta[0]\n",
    "        self._theta = None  # full parameter vector: intercept followed by the weights\n",
    "\n",
    "    def _sigmoid(self, t):\n",
    "        \"\"\"Return the logistic function 1 / (1 + exp(-t)) of t.\"\"\"\n",
    "        return 1. / (1. + np.exp(-t))\n",
    "\n",
    "    def fit(self, X_train, y_train, eta=0.01, n_iters = 1e4):\n",
    "        \"\"\"Train on X_train, y_train with batch gradient descent.\n",
    "\n",
    "        Args:\n",
    "            X_train: 2-D array, one row per sample.\n",
    "            y_train: 1-D array with one label per row of X_train; the loss\n",
    "                formula assumes the labels are 0 or 1.\n",
    "            eta: learning rate (step size).\n",
    "            n_iters: maximum number of gradient descent steps.\n",
    "\n",
    "        Returns:\n",
    "            self, so calls can be chained.\n",
    "        \"\"\"\n",
    "        assert X_train.shape[0] == y_train.shape[0], \"the size of X_train must be equal to the size of y_train\"\n",
    "\n",
    "        def J(theta, X_b, y):\n",
    "            # Mean cross-entropy loss; uses the same 1/m scaling as the gradient in dJ.\n",
    "            y_hat = self._sigmoid(X_b.dot(theta))\n",
    "            # Keep probabilities strictly inside (0, 1) so log() stays finite\n",
    "            # and the convergence test below never compares inf/nan values.\n",
    "            y_hat = np.clip(y_hat, 1e-15, 1 - 1e-15)\n",
    "            return -np.sum(y * np.log(y_hat) + (1 - y) * np.log(1 - y_hat)) / len(y)\n",
    "\n",
    "        def dJ(theta, X_b, y):\n",
    "            # Gradient of the mean cross-entropy loss with respect to theta.\n",
    "            return X_b.T.dot(self._sigmoid(X_b.dot(theta)) - y) / len(X_b)\n",
    "\n",
    "        def gradient_descent(X_b, y, initial_theta, eta, n_iters = 1e4, epsilon=1e-8):\n",
    "            theta = initial_theta\n",
    "            last_cost = J(theta, X_b, y)\n",
    "            i_iter = 0\n",
    "            while i_iter < n_iters:\n",
    "                theta = theta - eta * dJ(theta, X_b, y)\n",
    "                cost = J(theta, X_b, y)\n",
    "                # Stop early once a step no longer changes the loss noticeably.\n",
    "                if abs(cost - last_cost) < epsilon:\n",
    "                    break\n",
    "                last_cost = cost\n",
    "                i_iter += 1\n",
    "            return theta\n",
    "\n",
    "        # Prepend a column of ones so theta[0] acts as the intercept.\n",
    "        X_b = np.hstack([np.ones((len(X_train), 1)), X_train])\n",
    "        initial_theta = np.zeros(X_b.shape[1])\n",
    "        # Forward n_iters so the caller's iteration cap is actually honoured.\n",
    "        self._theta = gradient_descent(X_b, y_train, initial_theta, eta, n_iters)\n",
    "        self.interception_ = self._theta[0]\n",
    "        self.coef_ = self._theta[1:]\n",
    "        return self\n",
    "\n",
    "    def predict_proba(self, X_predict):\n",
    "        \"\"\"Return sigmoid(X_b . theta) for each row of X_predict (the modelled probability of label 1).\"\"\"\n",
    "        assert self.interception_ is not None and self.coef_ is not None, \"must fit before predict\"\n",
    "        assert X_predict.shape[1] == len(self.coef_), \"the feature number of X_predict must equal to X_train\"\n",
    "\n",
    "        X_b = np.hstack([np.ones((len(X_predict), 1)), X_predict])\n",
    "        return self._sigmoid(X_b.dot(self._theta))\n",
    "\n",
    "    def predict(self, X_predict):\n",
    "        \"\"\"Return an int array of 0/1 labels: 1 where predict_proba is at least 0.5.\"\"\"\n",
    "        assert self.interception_ is not None and self.coef_ is not None, \"must fit before predict\"\n",
    "        assert X_predict.shape[1] == len(self.coef_), \"the feature number of X_predict must equal to X_train\"\n",
    "\n",
    "        proba = self.predict_proba(X_predict)\n",
    "        return np.array(proba>=0.5, dtype=int)\n",
    "\n",
    "    def score(self, X_test, y_test):\n",
    "        \"\"\"Return the accuracy on X_test, y_test: the fraction of predictions equal to y_test.\"\"\"\n",
    "        assert X_test.shape[0] == len(y_test), \"the size of X_test must be equal to the size of y_test\"\n",
    "\n",
    "        y_predict = self.predict(X_test)\n",
    "        # Accuracy, not R^2: this is a classifier, so count exact label matches.\n",
    "        return np.mean(np.asarray(y_test) == y_predict)\n",
    "\n",
    "    def __repr__(self):\n",
    "        return \"LogisticRegression()\""
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 实现逻辑回归"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "from sklearn import datasets"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [],
   "source": [
    "iris = datasets.load_iris()\n",
    "# Logistic regression here only handles binary classification, so keep just the\n",
    "# samples of two flower classes (labels 0 and 1) and the first two feature columns.\n",
    "X = iris.data[iris.target < 2, :2]\n",
    "y = iris.target[iris.target < 2]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(100, 2)"
      ]
     },
     "execution_count": 11,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Sanity check: rows are the samples kept by the y<2 filter, columns are the 2 selected features\n",
    "X.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(100,)"
      ]
     },
     "execution_count": 12,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Sanity check: one label per kept sample, so the length should match X's first dimension\n",
    "y.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAXQAAAD5CAYAAAA3Os7hAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4xLjAsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy+17YcXAAAXn0lEQVR4nO3dfYxcV3nH8d/j2RQwASI1qxLFL6sKhNREzotXISgVCjiteLGMKkBytZQaFbm1SQlqK9oQKRKWUFVValNI7WgJqkLjltDwUoNSWhKIGlRhtA4hEEwrg+LEhDabpCSlBirbT/+4s/Ls7MzsPTNz5p5z5vuRrnbmzsnd59x798n1uc89Y+4uAED+NjQdAABgPEjoAFAIEjoAFIKEDgCFIKEDQCFI6ABQiJm6Dc2sJWlJ0g/dfWfXZ3sk/bmkH7ZX3e7udw7a3sUXX+xzc3NBwQLAtDt27Ngz7j7b67PaCV3STZKOS3p5n8/vcfcb625sbm5OS0tLAb8eAGBmJ/t9VmvIxcw2SXqrpIFX3QCA5tQdQ79N0gclnRvQ5u1m9qiZ3Wtmm3s1MLO9ZrZkZkvLy8uhsQIABlg3oZvZTklPu/uxAc2+IGnO3bdJul/SXb0aufuiu8+7+/zsbM8hIADAkOpcoV8naZeZPS7pU5LeaGZ3dzZw92fd/efttx+XtH2sUQIA1rVuQnf3m919k7vPSdot6Svu/q7ONmZ2ScfbXapungIAJiikymUVMzsgacndj0h6v5ntknRG0nOS9ownPABAXUEPFrn7gys16O5+azuZr1zFX+buV7j7G9z9ezGCBRpx+LA0Nydt2FD9PHy46YiAnoa+QgemwuHD0t690unT1fuTJ6v3krSw0FxcQA88+g8Mcsst55P5itOnq/VAYkjowCBPPBG2HmgQCR0YZMuWsPVAg0jowCAf+Yi0cePqdRs3VuuBxJDQgUEWFqTFRWnrVsms+rm4yA1RJIkqF2A9CwskcGSBK3QAKAQJHQAKQUIHgEKQ0AGgECR0ACgECR0ACkFCB4BCkNABoBAkdAAoBAkd5eCLKDDlePQfZeCLKACu0FEIvogCIKGjEHwRBUBCRyH4IgqAhI5C8EUUAAkdheCLKACqXFAQvogCU44rdIyO+m8gCVyhYzTUfwPJ4Aodo6H+G0gGCR2jof4bSAYJHaOh/htIBgkdo6H+G0gGCR2jof4bSAZVLhgd9d9AEmpfoZtZy8y+aWZf7PHZi8zsHjM7YWZHzWxunEEC2aAmHw0KGXK5SdLxPp/9jqT/dvdXSfpLSX82amBAdlZq8k+elNzP1+ST1DEhtRK6mW2S9FZJd/Zp8jZJd7Vf3ytph5nZ6OEBGaEmHw2re4V+m6QPSjrX5/NLJT0pSe5+RtLzkn6xu5GZ7TWzJTNbWl5eHiJcIGHU5KNh6yZ0M9sp6Wl3PzaoWY91vmaF+6K7z7v7/OzsbECYQAaoyUfD6lyhXydpl5k9LulTkt5oZnd3tTklabMkmdmMpFdIem6McQLpoyYfDVs3obv7ze6+yd3nJO2W9BV3f1dXsyOSfrv9+h3tNmuu0IGiUZOPhg1dh25mByQtufsRSZ+Q9LdmdkLVlfnuMcUH5IWafDQo6ElRd3/Q3Xe2X9/aTuZy95+5+zvd/VXufo27/yBGsJgy+/dLMzPV1e7MTPUeQF88KYo07d8vHTp0/v3Zs+ffHzzYTExA4pjLBWlaXAxbD4CEjkSdPRu2HgAJHYlqtcLWAyChI1Er30tadz0AbooiUSs3PhcXq2GWVqtK5twQBfoioSNdBw+SwIEADLmgtxtuqOq/V5Ybbmg6ouYwxzkyQULHWjfcID3wwOp1DzwwnUmdOc6REWtqypX5+XlfWlpq5HdjHYOmsp+2KXrm5qok3m3rVunxxycdDSAzO+bu870+4wodGIQ5zpEREjowCHOcIyMkdKy1Y0fY+pIxxzky
QkLHWvffvzZ579hRrZ82zHGOjHBTFAAywk1RhItVex2yXeq/gSA8KYq1VmqvT5+u3q/UXkujDTWEbDdWDEDBGHLBWrFqr0O2S/030BNDLggTq/Y6ZLvUfwPBSOhYK1btdch2qf8GgpHQsVas2uuQ7VL/DQQjoWOtWLXXIdul/hsIxk1RAMgIN0VjSKFGOjSGFGIGEA116MNIoUY6NIYUYgYQFUMuw0ihRjo0hhRiBjAyhlzGLYUa6dAYUogZQFQk9GGkUCMdGkMKMQOIioQ+jBRqpENjSCFmAFGR0IeRQo10aAwpxAwgKm6KAkBGRropamYvNrNvmNm3zOwxM/twjzZ7zGzZzB5pL+8dR+AYs/37pZmZ6gp9ZqZ6P462qdS3pxIH0BR3H7hIMkkXtl9fIOmopGu72uyRdPt62+pctm/f7pigffvcpbXLvn2jtb37bveNG1e327ixWj9JqcQBRCZpyfvk1aAhFzPbKOlrkva5+9GO9Xskzbv7jXW3xZDLhM3MSGfPrl3faklnzgzfNpX69lTiACIbuQ7dzFpm9oikpyV9uTOZd3i7mT1qZvea2eY+29lrZktmtrS8vFy7AxiDXgm63/qQtqnUt6cSB9CgWgnd3c+6+5WSNkm6xswu72ryBUlz7r5N0v2S7uqznUV3n3f3+dnZ2VHiRqhWq/76kLap1LenEgfQoKCyRXf/saQHJb2pa/2z7v7z9tuPS9o+lugwPivzttRZH9I2lfr2VOIAmtRvcH1lkTQr6aL265dIekjSzq42l3S8/g1JX19vu9wUbcC+fe6tVnXDsNXqfZNzmLZ33+2+dau7WfWzqRuRqcQBRKRRboqa2TZVQygtVVf0n3b3A2Z2oL3hI2b2p5J2SToj6TlVN02/N2i73BQFgHAj3RR190fd/Sp33+bul7v7gfb6W939SPv1ze5+mbtf4e5vWC+ZFyFWzXNI/XfMbYf0L8d9kRlK7FFLv0v32EvWQy6xap5D6r9jbjukfznui8xQYo9OGlcd+jhlPeQSq+Y5pP475rZD+pfjvsgMJfboNGjIhYQ+jA0bqgulbmbSuXPDb9es/2ejHqeQbYf0L8d9kZlYuxh54gsuxi1WzXNI/XfMbYf0L8d9kRlK7FEXCX0YsWqeQ+q/Y247pH857ovMUGKP2voNrsdesr4p6h6v5jmk/jvmtkP6l+O+yAwl9lghbooCQBkYQ0clhdpyZI3TIm0zTQeACTl8uBp/Pn26en/y5Pnx6O6voQtpi6nBaZE+hlymRQq15cgap0UaGHJB2HzhzC2OHjgt0kdCnxYp1JYja5wW6SOhT4sUasuRNU6L9JHQp8XCgrS4WA14mlU/Fxd7380KaYupwWmRPm6KAkBGpvumaKzC2ZDtpjKvN0XESSn9cJTevxAT2xf9HiGNvUzk0f9YE0mHbDeVeb2ZVDsppR+O0vsXYtz7QlP76H+swtmQ7aYyrzdFxEkp/XCU3r8Q494X0zsfeqyJpEO2m8q83kyqnZTSD0fp/Qsx7n0xvWPosQpnQ7abyrzeFBEnpfTDUXr/QkxyX5Sd0GMVzoZsN5V5vSkiTkrph6P0/oWY6L7oN7gee5nYfOixJpIO2W4q83ozqXZSSj8cpfcvxDj3hab2pigAFGZ6x9Bjor4dyEKsP5Mk6+z7XbrHXrL+Cjrq24EsxPozabLOXgy5jBn17UAWYv2ZNFlnz5DLuMWaGDpku73O0kHrgSkU688k1bnhSejDoL4dyEKsP5NU6+xJ6MOgvh3IQqw/k2Tr7PsNrsdesr4p6k59O5CJWH8mTdXZi5uiAFCGkW6KmtmLzewbZvYtM3vMzD7co82LzOweMzthZkfNbG70sPsILf5Mslh0gJCi2cL3RcxwY+7mumL2L7NDHaTw0340/S7dVxZJJunC9usLJB2VdG1Xm/2S7mi/3i3pnvW2O9SQS2jxZ26TMocUzRa+L2KGG3M31xWzf5kd6iCFn/a1aMCQS9C4t6SNkh6W
9Nqu9f8s6XXt1zOSnlF7at5+y1AJfevW3n+JW7eOp33TVgb6updWa23bwvdFzHBj7ua6YvYvs0MdpPDTvpZBCb3WGLqZtSQdk/QqSX/t7n/c9fl3JL3J3U+133+/nfSf6Wq3V9JeSdqyZcv2k70q8wcJnVg4t0mZQ+ZOL3xfxAw35m6uK2b/MjvUQQo/7WsZ+cEidz/r7ldK2iTpGjO7vPt39PrPemxn0d3n3X1+dna2zq9eLbT4M9Vi0X5CimYL3xcxw425m+uK2b/MDnWQwk/7kQXVobv7jyU9KOlNXR+dkrRZksxsRtIrJD03hvhWCy3+TLZYtI+QotnC90XMcGPu5rpi9i+zQx2k8NN+dP3GYlYWSbOSLmq/fomkhyTt7GrzPq2+Kfrp9bY7dB16aPFnbpMyhxTNFr4vYoYbczfXFbN/mR3qIIWf9uvSKGPoZrZN0l2SWqqu6D/t7gfM7EB7w0fM7MWS/lbSVaquzHe7+w8GbZc6dAAIN2gMfWa9/9jdH1WVqLvX39rx+meS3jlKkACA0ZQ/l8tUPVWAukJOixROoZgP0+T24FQKxyNZ/cZiYi8TmculxKcKMLKQ0yKFUyjmwzS5PTiVwvFomqZ2LpcmZ6FHskJOixROodAYUuhfbtvNyaAx9LITeolPFWBkIadFCqdQzIdpcntwKoXj0bTp/caiaXuqALWEnBYpnEIxH6bJ7cGpFI5HyspO6FP3VAHqCDktUjiFYj5Mk9uDUykcj6T1G1yPvUzsCy5Ke6oAYxFyWqRwCsV8mCa3B6dSOB5N0tTeFAWAwkzvGDowBiFfhpGK3GJOpbY8lTiG1u/SPfaS/XeKYiqEfBlGKnKLOZXa8lTiWI8YcgGGMzMjnT27dn2rJZ05M/l46sgt5lRqy1OJYz0MuQBD6pUYB61PQW4xP/FE2PrS4xgFCR0YIOTLMFKRW8yp1JanEscoSOjAACFfhpGK3GJOpbY8lThG0m9wPfbCTVHkIuTLMFKRW8yp1JanEscg4qYoAJSBm6KIKsfa3Vgxx6r/znEfowH9Lt1jLwy5lCGX2t1OsWKOVf+d4z5GPGLIBbHkUrvbKVbMseq/c9zHiIchF0STY+1urJhj1X/nuI/RDBI6RpJj7W6smGPVf+e4j9EMEjpGkmPtbqyYY9V/57iP0ZB+g+uxF26KliOH2t1usWKOVf+d4z5GHOKmKACUgZuimAqxarVDtku9OJo003QAwDgcPlyNVZ8+Xb0/efL82PXCwmS2GysGoC6GXFCEWLXaIdulXhyTwJALiherVjtku9SLo2kkdBQhVq12yHapF0fTSOgoQqxa7ZDtUi+OppHQUYSFBWlxsRqvNqt+Li6OfjMyZLuxYgDq4qYoAGRkpJuiZrbZzL5qZsfN7DEzu6lHm+vN7Hkze6S93DqOwNGcHOupqRePj/2WuH6PkK4ski6RdHX79csk/YekX+lqc72kL663rc6FR//TleP82yEx59i/FLDf0qBxPvpvZv8o6XZ3/3LHuusl/ZG776y7HYZc0pVjPTX14vGx39IwaMglKKGb2Zykf5V0ubu/0LH+ekmfkXRK0lOqkvtjPf77vZL2StKWLVu2n+x1dqBxGzZU11/dzKRz5yYfTx0hMefYvxSw39IwlgeLzOxCVUn7A53JvO1hSVvd/QpJH5P0+V7bcPdFd5939/nZ2dm6vxoTlmM9NfXi8bHf0lcroZvZBaqS+WF3/2z35+7+grv/pP36PkkXmNnFY40UE5NjPTX14vGx3zLQb3B9ZZFkkj4p6bYBbV6p88M310h6YuV9v4WbomnLcf7tkJhz7F8K2G/N0yg3Rc3sVyU9JOnbklZGyj4kaUv7fwh3mNmNkvZJOiPpp5L+wN3/bdB2uSkKAOFGGkN396+5u7n7Nne/sr3c5+53uPsd7Ta3u/tl7n6Fu1+7XjLHeFATvNr+/dLMTHWTbmameg9ME+ZDzxRzb6+2f7906ND592fPnn9/8GAzMQGT
xqP/maImeLWZmSqJd2u1pDNnJh8PEAvzoReIubdX65XMB60HSkRCzxQ1wau1WmHrgRKR0DNFTfBqK/cP6q4HSkRCzxRzb6928KC0b9/5K/JWq3rPDVFME26KAkBGuClaV+GF3YV3r/j+pYB9nLh+j5DGXpJ79L/wyZ4L717x/UsB+zgNGud86OOS3JBL4YXdhXev+P6lgH2chrHNhz5OySX0wid7Lrx7xfcvBezjNDCGXkfhhd2Fd6/4/qWAfZw+EvqKwgu7C+9e8f1LAfs4A/0G12Mvyd0UdS9+sufCu1d8/1LAPm6euCkKAGVgDB3IXMz6b2rLy8F86EDiYs59z7z6ZWHIBUhczPpvasvzw5ALkLGYc98zr35ZSOhA4mLWf1NbXhYSOpC4mPXf1JaXhYQOJC7m3PfMq18WbooCQEa4KQoAU4CEDgCFIKEDQCFI6ABQCBI6ABSChA4AhSChA0AhSOgAUIh1E7qZbTazr5rZcTN7zMxu6tHGzOyjZnbCzB41s6vjhItRMO81ULY686GfkfSH7v6wmb1M0jEz+7K7f7ejzZslvbq9vFbSofZPJIJ5r4HyrXuF7u4/cveH26//R9JxSZd2NXubpE+2v/Lu65IuMrNLxh4thnbLLeeT+YrTp6v1AMoQNIZuZnOSrpJ0tOujSyU92fH+lNYmfZnZXjNbMrOl5eXlsEgxEua9BspXO6Gb2YWSPiPpA+7+QvfHPf6TNbN+ufuiu8+7+/zs7GxYpBgJ814D5auV0M3sAlXJ/LC7f7ZHk1OSNne83yTpqdHDw7gw7zVQvjpVLibpE5KOu/tf9Gl2RNK729Uu10p63t1/NMY4MSLmvQbKV6fK5TpJvyXp22b2SHvdhyRtkSR3v0PSfZLeIumEpNOS3jP+UDGqhQUSOFCydRO6u39NvcfIO9u4pPeNKygAQDieFAWAQpDQAaAQJHQAKAQJHQAKQUIHgEKQ0AGgECR0ACiEVSXkDfxis2VJJxv55eu7WNIzTQcREf3LV8l9k+hfHVvdvedkWI0l9JSZ2ZK7zzcdRyz0L18l902if6NiyAUACkFCB4BCkNB7W2w6gMjoX75K7ptE/0bCGDoAFIIrdAAoBAkdAAox1QndzFpm9k0z+2KPz/aY2bKZPdJe3ttEjKMws8fN7Nvt+Jd6fG5m9lEzO2Fmj5rZ1U3EOYwafbvezJ7vOH63NhHnsMzsIjO718y+Z2bHzex1XZ9ne+ykWv3L9viZ2Ws64n7EzF4wsw90tYly/Op8Y1HJbpJ0XNLL+3x+j7vfOMF4YniDu/d7kOHNkl7dXl4r6VD7Zy4G9U2SHnL3nROLZrz+StKX3P0dZvYLkrq+ETb7Y7de/6RMj5+7/7ukK6XqolHSDyV9rqtZlOM3tVfoZrZJ0lsl3dl0LA16m6RPeuXrki4ys0uaDmramdnLJb1e1Xf5yt3/z91/3NUs22NXs3+l2CHp++7e/VR8lOM3tQld0m2SPijp3IA2b2//c+heM9s8objGySX9i5kdM7O9PT6/VNKTHe9PtdflYL2+SdLrzOxbZvZPZnbZJIMb0S9LWpb0N+0hwTvN7KVdbXI+dnX6J+V7/DrtlvT3PdZHOX5TmdDNbKekp9392IBmX5A05+7bJN0v6a6JBDde17n71ar+efc+M3t91+e9vis2lzrW9fr2sKo5L66Q9DFJn590gCOYkXS1pEPufpWk/5X0J11tcj52dfqX8/GTJLWHknZJ+odeH/dYN/Lxm8qELuk6SbvM7HFJn5L0RjO7u7OBuz/r7j9vv/24pO2TDXF07v5U++fTqsbwrulqckpS5788Nkl6ajLRjWa9vrn7C+7+k/br+yRdYGYXTzzQ4ZySdMrdj7bf36sqAXa3yfLYqUb/Mj9+K94s6WF3/68en0U5flOZ0N39Znff5O5zqv5J9BV3f1dnm67xrF2qbp5mw8xeamYvW3kt6dclfaer2RFJ727fcb9W0vPu/qMJhxqsTt/M7JVm
Zu3X16g615+ddKzDcPf/lPSkmb2mvWqHpO92Ncvy2En1+pfz8evwm+o93CJFOn7TXuWyipkdkLTk7kckvd/Mdkk6I+k5SXuajG0IvyTpc+2/iRlJf+fuXzKz35Mkd79D0n2S3iLphKTTkt7TUKyh6vTtHZL2mdkZST+VtNvzeiz69yUdbv+z/QeS3lPIsVuxXv+yPn5mtlHSr0n63Y510Y8fj/4DQCGmcsgFAEpEQgeAQpDQAaAQJHQAKAQJHQAKQUIHgEKQ0AGgEP8P3x3UprO23Z8AAAAASUVORK5CYII=\n",
      "text/plain": [
       "<Figure size 432x288 with 1 Axes>"
      ]
     },
     "metadata": {
      "needs_background": "light"
     },
     "output_type": "display_data"
    }
   ],
   "source": [
    "# Class 0 in red and class 1 in blue, plotted over the two selected features.\n",
    "# Labels come from the dataset itself so the figure is readable on its own.\n",
    "plt.scatter(X[y==0,0], X[y==0,1], color='red', label=iris.target_names[0])\n",
    "plt.scatter(X[y==1,0], X[y==1,1], color='blue', label=iris.target_names[1])\n",
    "plt.xlabel(iris.feature_names[0])\n",
    "plt.ylabel(iris.feature_names[1])\n",
    "plt.legend()\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 使用逻辑回归"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [],
   "source": [
    "from sklearn.model_selection import train_test_split\n",
    "\n",
    "# random_state is fixed so the train/test partition is the same on every run\n",
    "X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=666)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [
    {
     "ename": "AttributeError",
     "evalue": "'LogisticRegression' object has no attribute '_sigmoid'",
     "output_type": "error",
     "traceback": [
      "\u001b[1;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[1;31mAttributeError\u001b[0m                            Traceback (most recent call last)",
      "\u001b[1;32m<ipython-input-15-5d184a914da5>\u001b[0m in \u001b[0;36m<module>\u001b[1;34m\u001b[0m\n\u001b[0;32m      1\u001b[0m \u001b[0mlog_reg\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mLogisticRegression\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m----> 2\u001b[1;33m \u001b[0mlog_reg\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mfit\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mX_train\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0my_train\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m",
      "\u001b[1;32m<ipython-input-3-12f471f854d2>\u001b[0m in \u001b[0;36mfit\u001b[1;34m(self, X_train, y_train, eta, n_iters)\u001b[0m\n\u001b[0;32m     41\u001b[0m         \u001b[0mX_b\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mnp\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mhstack\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m[\u001b[0m\u001b[0mnp\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mones\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mlen\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mX_train\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m,\u001b[0m \u001b[1;36m1\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mX_train\u001b[0m\u001b[1;33m]\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m     42\u001b[0m         \u001b[0minitial_theta\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mnp\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mzeros\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mX_b\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mshape\u001b[0m\u001b[1;33m[\u001b[0m\u001b[1;36m1\u001b[0m\u001b[1;33m]\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m---> 43\u001b[1;33m         \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0m_theta\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mgradient_descent\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mX_b\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0my_train\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0minitial_theta\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0meta\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m     44\u001b[0m         \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0minterception_\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0m_theta\u001b[0m\u001b[1;33m[\u001b[0m\u001b[1;36m0\u001b[0m\u001b[1;33m]\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m     45\u001b[0m         
\u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mcoef_\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0m_theta\u001b[0m\u001b[1;33m[\u001b[0m\u001b[1;36m1\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m]\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n",
      "\u001b[1;32m<ipython-input-3-12f471f854d2>\u001b[0m in \u001b[0;36mgradient_descent\u001b[1;34m(X_b, y, initial_theta, eta, n_iters, epsilon)\u001b[0m\n\u001b[0;32m     31\u001b[0m             \u001b[0mi_iter\u001b[0m \u001b[1;33m=\u001b[0m \u001b[1;36m0\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m     32\u001b[0m             \u001b[1;32mwhile\u001b[0m \u001b[0mi_iter\u001b[0m \u001b[1;33m<\u001b[0m \u001b[0mn_iters\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m---> 33\u001b[1;33m                 \u001b[0mgradient\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mdJ\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mtheta\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mX_b\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0my\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m     34\u001b[0m                 \u001b[0mlast_theta\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mtheta\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m     35\u001b[0m                 \u001b[0mtheta\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mtheta\u001b[0m \u001b[1;33m-\u001b[0m \u001b[0meta\u001b[0m \u001b[1;33m*\u001b[0m \u001b[0mgradient\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n",
      "\u001b[1;32m<ipython-input-3-12f471f854d2>\u001b[0m in \u001b[0;36mdJ\u001b[1;34m(theta, X_b, y)\u001b[0m\n\u001b[0;32m     25\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m     26\u001b[0m         \u001b[1;32mdef\u001b[0m \u001b[0mdJ\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mtheta\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mX_b\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0my\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m---> 27\u001b[1;33m             \u001b[1;32mreturn\u001b[0m \u001b[0mX_b\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mT\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mdot\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0m_sigmoid\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mX_b\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mdot\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mtheta\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m-\u001b[0m\u001b[0my\u001b[0m\u001b[1;33m)\u001b[0m \u001b[1;33m/\u001b[0m \u001b[0mlen\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mX_b\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m     28\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m     29\u001b[0m         \u001b[1;32mdef\u001b[0m \u001b[0mgradient_descent\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mX_b\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0my\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0minitial_theta\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0meta\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mn_iters\u001b[0m \u001b[1;33m=\u001b[0m \u001b[1;36m1e4\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mepsilon\u001b[0m\u001b[1;33m=\u001b[0m\u001b[1;36m1e-8\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n",
      "\u001b[1;31mAttributeError\u001b[0m: 'LogisticRegression' object has no attribute '_sigmoid'"
     ]
    }
   ],
   "source": [
    "# Train on the training split using fit()'s default eta and n_iters\n",
    "log_reg = LogisticRegression()\n",
    "log_reg.fit(X_train, y_train)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
